// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/transform/loop_to_for_loop.h"

#include "src/ast/break_statement.h"
#include "src/ast/for_loop_statement.h"
#include "src/program_builder.h"
#include "src/sem/block_statement.h"
#include "src/sem/function.h"
#include "src/sem/statement.h"
#include "src/sem/variable.h"
#include "src/utils/scoped_assignment.h"

TINT_INSTANTIATE_TYPEINFO(tint::transform::LoopToForLoop);

namespace tint {
namespace transform {
namespace {

/// @param block the block statement to inspect
/// @returns true if `block` holds exactly one statement and that statement is
/// a break statement, otherwise false
bool IsBlockWithSingleBreak(const ast::BlockStatement* block) {
  const auto& contents = block->statements;
  // Short-circuit: the element is only inspected once the size is known to
  // be exactly one.
  return contents.size() == 1 && contents[0]->Is<ast::BreakStatement>();
}

/// @param sem the semantic information used to resolve `var`
/// @param var the variable whose users are examined
/// @param stmt the statement to look for among the users' statements
/// @returns true if any user of `var` belongs to a semantic statement whose
/// declaration is `stmt`, otherwise false
bool IsVarUsedByStmt(const sem::Info& sem,
                     const ast::Variable* var,
                     const ast::Statement* stmt) {
  auto* sem_var = sem.Get(var);
  for (auto* usage : sem_var->Users()) {
    auto* owning_stmt = usage->Stmt();
    // Users without an owning statement can never match and are skipped.
    if (owning_stmt && owning_stmt->Declaration() == stmt) {
      return true;
    }
  }
  return false;
}

}  // namespace

// Constructor: compiler-generated implementation, defined out-of-line here.
LoopToForLoop::LoopToForLoop() = default;

// Destructor: compiler-generated implementation, defined out-of-line here.
LoopToForLoop::~LoopToForLoop() = default;

// Registers a replacement callback for every ast::LoopStatement and then
// clones the program. The callback builds an ast::ForLoopStatement (with no
// initializer) when the loop has the supported shape described below;
// otherwise it returns nullptr (presumably leaving the loop to be cloned
// unchanged - confirm against CloneContext::ReplaceAll).
void LoopToForLoop::Run(CloneContext& ctx, const DataMap&, DataMap&) {
  ctx.ReplaceAll([&](const ast::LoopStatement* loop) -> const ast::Statement* {
    // The for-loop condition comes from the leading statement of the loop
    // body, which has to be an if-statement of one of these two forms:
    //   loop {  if (condition) { break; } ... }
    //   loop {  if (condition) {} else { break; } ... }
    // That is, either the true block is a lone 'break' with no else
    // statements, or the true block is empty and there is exactly one
    // condition-less else statement whose block is a lone 'break'.
    auto& body_stmts = loop->body->statements;
    if (body_stmts.empty()) {
      return nullptr;
    }
    auto* leading_if = body_stmts[0]->As<ast::IfStatement>();
    if (leading_if == nullptr) {
      return nullptr;
    }

    const auto& else_stmts = leading_if->else_statements;

    // First form: the loop exits when the condition holds, so the for-loop
    // must keep running while the condition does NOT hold.
    const bool breaks_when_true =
        IsBlockWithSingleBreak(leading_if->body) && else_stmts.empty();
    if (!breaks_when_true) {
      // Second form: the loop exits when the condition fails, so the
      // condition can be used as-is.
      const bool breaks_when_false =
          leading_if->body->Empty() && else_stmts.size() == 1 &&
          else_stmts[0]->condition == nullptr &&
          IsBlockWithSingleBreak(else_stmts[0]->body);
      if (!breaks_when_false) {
        return nullptr;
      }
    }

    // The continuing block is either absent, or holds exactly one statement
    // which must be an assignment or a function call.
    const ast::Statement* for_continuing = nullptr;
    if (loop->continuing != nullptr) {
      const auto& cont_stmts = loop->continuing->statements;
      if (cont_stmts.size() != 1) {
        return nullptr;
      }

      auto* cont_stmt = cont_stmts[0];
      const bool is_supported_kind =
          cont_stmt->IsAnyOf<ast::AssignmentStatement, ast::CallStatement>();
      if (!is_supported_kind) {
        return nullptr;
      }

      // Give up if the continuing statement uses a variable that is declared
      // by one of the statements of the loop body.
      for (auto* body_stmt : body_stmts) {
        auto* decl = body_stmt->As<ast::VariableDeclStatement>();
        if (decl &&
            IsVarUsedByStmt(ctx.src->Sem(), decl->variable, cont_stmt)) {
          return nullptr;
        }
      }

      for_continuing = ctx.Clone(cont_stmt);
    }

    auto* for_condition = ctx.Clone(leading_if->condition);
    if (breaks_when_true) {
      for_condition = ctx.dst->create<ast::UnaryOpExpression>(
          ast::UnaryOp::kNot, for_condition);
    }

    // The generated for-loop never has an initializer.
    ast::Statement* no_initializer = nullptr;

    // Drop the leading if-statement before cloning the body, as it has been
    // folded into the for-loop condition.
    ctx.Remove(loop->body->statements, leading_if);
    auto* for_body = ctx.Clone(loop->body);
    return ctx.dst->create<ast::ForLoopStatement>(no_initializer, for_condition,
                                                  for_continuing, for_body);
  });

  ctx.Clone();
}

}  // namespace transform
}  // namespace tint
